{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0d6ccf45-7c3c-4c3e-a5e1-49794c584902",
   "metadata": {},
   "source": [
    "# 在agent与tool之间共享记忆\n",
    "\n",
    "- 自定义一个工具，用LLMChain来总结内容\n",
    "- 使用readonlymemory来共享记忆\n",
    "- 观察共享与不共享的区别"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a18e1cb3-a7f2-4cd1-a04d-e40fbd29b9c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.agents import initialize_agent, Tool, AgentType\n",
    "from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory\n",
    "from langchain.llms import OpenAI\n",
    "from langchain.chains import LLMChain\n",
    "from langchain.prompts import PromptTemplate,MessagesPlaceholder\n",
    "from langchain.utilities import SerpAPIWrapper\n",
    "\n",
    "# Completion model shared by the summary chain and the agents.\n",
    "llm = OpenAI(model=\"gpt-3.5-turbo-instruct\", temperature=0)\n",
    "\n",
    "# Prompt for the summary chain: the conversation history first, then the tool input.\n",
    "template = \"\"\"以下是一段AI机器人和人类的对话:\n",
    "{chat_history}\n",
    "根据输入和上面的对话记录写一份对话总结.\n",
    "输入: {input}\"\"\"\n",
    "prompt = PromptTemplate(template=template, input_variables=[\"input\", \"chat_history\"])\n",
    "\n",
    "# Conversation buffer exposed to prompts under the `chat_history` key.\n",
    "memory = ConversationBufferMemory(return_messages=True, memory_key=\"chat_history\")\n",
    "\n",
    "# Read-only view of `memory`; this is what the summary chain is given.\n",
    "readonlymemory = ReadOnlySharedMemory(memory=memory)\n",
    "\n",
    "# Chain that fills `prompt` and sends it to `llm`, reading history via `readonlymemory`.\n",
    "summary_chain = LLMChain(\n",
    "    prompt=prompt,\n",
    "    llm=llm,\n",
    "    memory=readonlymemory,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f609dd5-38c4-48f6-a090-15d2206228ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the tools\n",
    "import os\n",
    "from getpass import getpass\n",
    "\n",
    "# Never store API keys in the notebook: use SERPAPI_API_KEY from the\n",
    "# environment and prompt for it (without echoing) only when it is missing.\n",
    "# NOTE(review): a key was previously committed in this cell; treat it as\n",
    "# leaked and rotate it.\n",
    "if not os.environ.get(\"SERPAPI_API_KEY\"):\n",
    "    os.environ[\"SERPAPI_API_KEY\"] = getpass(\"SERPAPI_API_KEY: \")\n",
    "\n",
    "# Search tool\n",
    "search = SerpAPIWrapper()\n",
    "\n",
    "# Summary tool\n",
    "def SummaryChainFun(history):\n",
    "    \"\"\"Run the summary chain on the tool input and return its output.\n",
    "\n",
    "    The return value is what the agent receives back from the tool, so the\n",
    "    result of ``summary_chain.run`` has to be returned instead of discarded.\n",
    "    \"\"\"\n",
    "    print(\"\\n==============总结链开始运行==============\")\n",
    "    print(\"输入历史: \",history)\n",
    "    return summary_chain.run(history)\n",
    "\n",
    "# Tools exposed to the agent: web search and conversation summary.\n",
    "tools = [\n",
    "    Tool(\n",
    "        name=\"Search\",\n",
    "        description=\"当需要了解实时的信息或者你不知道的事时候可以使用搜索工具\",\n",
    "        func=search.run,\n",
    "    ),\n",
    "    Tool(\n",
    "        name=\"Summary\",\n",
    "        description=\"当你被要求总结一段对话的时候可以使用这个工具，工具输入必须为字符串，只在必要时使用\",\n",
    "        func=SummaryChainFun,  # the summary chain behind this tool reads the memory\n",
    "    ),\n",
    "]\n",
    "print(tools)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58f483de-df2d-4276-ae85-46702a5a436d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Memory component\n",
    "# Keep using the `memory` object that `readonlymemory` already wraps.\n",
    "# Rebinding the name to a new ConversationBufferMemory here would make the\n",
    "# agent write to one buffer while the summary chain reads another, empty one.\n",
    "# Clearing the existing buffer keeps this cell safe to re-run.\n",
    "memory.clear()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b73d848-f610-457f-91bb-153983e6eba7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the agent: zero-shot ReAct over `tools`, with `memory` attached.\n",
    "agent_chain = initialize_agent(\n",
    "    tools=tools,\n",
    "    llm=llm,\n",
    "    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,\n",
    "    memory=memory,\n",
    "    handle_parsing_errors=True,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf449749-3555-4392-83b6-56fb9648544d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the prompt template of the agent created above.\n",
    "print(agent_chain.agent.llm_chain.prompt.template)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6b27f2f-2ab5-4f9b-8c88-8844bf76f978",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Override the prompt prefix and suffix so that the chat history is part of the prompt.\n",
    "prefix = \"\"\"Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:\"\"\"\n",
    "# The stray double quote that followed \"Begin!\" was a typo and has been removed.\n",
    "suffix = \"\"\"Begin!\n",
    "{chat_history}\n",
    "Question: {input}\n",
    "{agent_scratchpad}\"\"\"\n",
    "\n",
    "agent_chain = initialize_agent(\n",
    "    tools,\n",
    "    llm,\n",
    "    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,\n",
    "    verbose=True,\n",
    "    handle_parsing_errors=True,\n",
    "    agent_kwargs={\n",
    "        \"prefix\": prefix,\n",
    "        \"suffix\": suffix,\n",
    "        # Declare every placeholder used in the suffix. This is a plain string\n",
    "        # prompt, so the variables are listed by name rather than passed as\n",
    "        # MessagesPlaceholder objects (those belong to chat prompt templates).\n",
    "        \"input_variables\": [\"input\", \"chat_history\", \"agent_scratchpad\"],\n",
    "    },\n",
    "    memory=memory,\n",
    ")\n",
    "print(agent_chain.agent.llm_chain.prompt.template)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f63e4bb-b62b-4749-824e-a23f8f77aa05",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First turn: a factual question.\n",
    "agent_chain.run(input=\"美国第45任总统是谁?\")\n",
    "\n",
    "# Inspect what the agent's memory holds after the first turn.\n",
    "print(agent_chain.memory.buffer)\n",
    "\n",
    "# Second turn: ask about the conversation itself.\n",
    "agent_chain.run(input=\"我们都聊了什么？\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (langchain)",
   "language": "python",
   "name": "langchain-env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
